Project Description: Urban Traffic Incident Detection Using Real-Time Sensor Data¶
Objective¶
The objective of this project is to develop a classification model to detect and predict the severity of urban traffic incidents using real-time sensor data. This model aims to help in proactive traffic management and reduce congestion by providing timely alerts for traffic incidents.
import matplotlib.pyplot as plt
keywords = [
"Traffic Management",
"Safety",
"Resource Allocation",
"Data Decisions",
"Economic Impact",
"Environmental Benefits",
"Public Satisfaction",
"Smart City"
]
fig, ax = plt.subplots(figsize=(8, 4))
labels = keywords
sizes = [1] * len(keywords)
colors = ['skyblue', 'lightcoral', 'lightgreen', 'lightyellow', 'lightsalmon', 'lightpink', 'lightblue', 'lightgrey']
explode = [0.1] * len(keywords)
ax.pie(sizes, explode=explode, labels=labels, colors=colors, shadow=True, startangle=140)
ax.axis('equal')
plt.title('Why is it useful to do this analysis?', fontsize=14, fontweight='bold')
plt.show()
Features¶
The dataset contains 12,316 rows (one observation per recorded traffic incident) and 32 columns describing each incident.
Dataset Column Descriptions¶
- Time: The time of the accident.
- Day_of_week: The day of the week the accident occurred.
- Age_band_of_driver: Age group of the driver.
- Sex_of_driver: Gender of the driver.
- Educational_level: Educational level of the driver.
- Vehicle_driver_relation: Relationship between the vehicle owner and the driver.
- Driving_experience: Years of driving experience of the driver.
- Type_of_vehicle: Type of vehicle involved in the accident.
- Owner_of_vehicle: Ownership status of the vehicle.
- Service_year_of_vehicle: Number of years the vehicle has been in service.
- Defect_of_vehicle: Any defect present in the vehicle.
- Area_accident_occured: Area where the accident occurred.
- Lanes_or_Medians: Presence and type of lanes or medians on the road.
- Road_allignment: Alignment of the road at the accident site.
- Types_of_Junction: Type of junction where the accident occurred.
- Road_surface_type: Type of road surface.
- Road_surface_conditions: Condition of the road surface at the time of the accident.
- Light_conditions: Light conditions at the time of the accident.
- Weather_conditions: Weather conditions at the time of the accident.
- Type_of_collision: Type of collision in the accident.
- Number_of_vehicles_involved: Number of vehicles involved in the accident.
- Number_of_casualties: Number of casualties in the accident.
- Vehicle_movement: Movement of the vehicle at the time of the accident.
- Casualty_class: Class of the casualty (e.g., pedestrian, passenger).
- Sex_of_casualty: Gender of the casualty.
- Age_band_of_casualty: Age group of the casualty.
- Casualty_severity: Severity of the casualty's injuries.
- Work_of_casuality: Employment status of the casualty.
- Fitness_of_casuality: Fitness level of the casualty.
- Pedestrian_movement: Movement of the pedestrian at the time of the accident.
- Cause_of_accident: Cause of the accident.
- Accident_severity: Severity of the accident.
The project will involve data cleaning, exploratory data analysis (EDA), feature engineering, and the development of a machine learning model, grounded in statistical analysis, to classify the severity of traffic incidents from the provided features.
Import Library¶
import pandas as pd
from matplotlib import pyplot as plt
import seaborn as sns
import joblib
Data¶
file_path = '~/Downloads/RTA Dataset.csv'
data = pd.read_csv(file_path)
data.head()
| Time | Day_of_week | Age_band_of_driver | Sex_of_driver | Educational_level | Vehicle_driver_relation | Driving_experience | Type_of_vehicle | Owner_of_vehicle | Service_year_of_vehicle | ... | Vehicle_movement | Casualty_class | Sex_of_casualty | Age_band_of_casualty | Casualty_severity | Work_of_casuality | Fitness_of_casuality | Pedestrian_movement | Cause_of_accident | Accident_severity | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 17:02:00 | Monday | 18-30 | Male | Above high school | Employee | 1-2yr | Automobile | Owner | Above 10yr | ... | Going straight | na | na | na | na | NaN | NaN | Not a Pedestrian | Moving Backward | Slight Injury |
| 1 | 17:02:00 | Monday | 31-50 | Male | Junior high school | Employee | Above 10yr | Public (> 45 seats) | Owner | 5-10yrs | ... | Going straight | na | na | na | na | NaN | NaN | Not a Pedestrian | Overtaking | Slight Injury |
| 2 | 17:02:00 | Monday | 18-30 | Male | Junior high school | Employee | 1-2yr | Lorry (41?100Q) | Owner | NaN | ... | Going straight | Driver or rider | Male | 31-50 | 3 | Driver | NaN | Not a Pedestrian | Changing lane to the left | Serious Injury |
| 3 | 1:06:00 | Sunday | 18-30 | Male | Junior high school | Employee | 5-10yr | Public (> 45 seats) | Governmental | NaN | ... | Going straight | Pedestrian | Female | 18-30 | 3 | Driver | Normal | Not a Pedestrian | Changing lane to the right | Slight Injury |
| 4 | 1:06:00 | Sunday | 18-30 | Male | Junior high school | Employee | 2-5yr | NaN | Owner | 5-10yrs | ... | Going straight | na | na | na | na | NaN | NaN | Not a Pedestrian | Overtaking | Slight Injury |
5 rows × 32 columns
data_summary = pd.DataFrame({
'Data Type': data.dtypes,
'Missing Values': data.isnull().sum(),
'Unique Values': data.nunique()
})
data_summary = data_summary.sort_values(by='Missing Values', ascending=False)
data_summary
| Data Type | Missing Values | Unique Values | |
|---|---|---|---|
| Defect_of_vehicle | object | 4427 | 3 |
| Service_year_of_vehicle | object | 3928 | 6 |
| Work_of_casuality | object | 3198 | 7 |
| Fitness_of_casuality | object | 2635 | 5 |
| Type_of_vehicle | object | 950 | 17 |
| Types_of_Junction | object | 887 | 8 |
| Driving_experience | object | 829 | 7 |
| Educational_level | object | 741 | 7 |
| Vehicle_driver_relation | object | 579 | 4 |
| Owner_of_vehicle | object | 482 | 4 |
| Lanes_or_Medians | object | 385 | 7 |
| Vehicle_movement | object | 308 | 13 |
| Area_accident_occured | object | 239 | 14 |
| Road_surface_type | object | 172 | 5 |
| Type_of_collision | object | 155 | 10 |
| Road_allignment | object | 142 | 9 |
| Casualty_class | object | 0 | 4 |
| Pedestrian_movement | object | 0 | 9 |
| Cause_of_accident | object | 0 | 20 |
| Casualty_severity | object | 0 | 4 |
| Age_band_of_casualty | object | 0 | 6 |
| Sex_of_casualty | object | 0 | 3 |
| Time | object | 0 | 1074 |
| Road_surface_conditions | object | 0 | 4 |
| Number_of_casualties | int64 | 0 | 8 |
| Number_of_vehicles_involved | int64 | 0 | 6 |
| Weather_conditions | object | 0 | 9 |
| Light_conditions | object | 0 | 4 |
| Day_of_week | object | 0 | 7 |
| Sex_of_driver | object | 0 | 3 |
| Age_band_of_driver | object | 0 | 5 |
| Accident_severity | object | 0 | 3 |
data_summary.groupby('Data Type').size().plot(kind='barh', color=sns.palettes.mpl_palette('Dark2'))
plt.gca().spines[['top', 'right',]].set_visible(False)
figsize = (12, 1.2 * len(data_summary['Data Type'].unique()))
plt.figure(figsize=figsize)
sns.violinplot(data_summary, x='Missing Values', hue='Data Type', inner='stick', palette='Dark2')
sns.despine(top=True, right=True, bottom=True, left=True)
plt.figure(figsize=(8,4))
colours = ['#34495E', 'seagreen']
sns.heatmap(data.isnull(), cmap=sns.color_palette(colours))
missing_values = data.isnull().sum()
print(missing_values[missing_values > 0])
Educational_level            741
Vehicle_driver_relation      579
Driving_experience           829
Type_of_vehicle              950
Owner_of_vehicle             482
Service_year_of_vehicle     3928
Defect_of_vehicle           4427
Area_accident_occured        239
Lanes_or_Medians             385
Road_allignment              142
Types_of_Junction            887
Road_surface_type            172
Type_of_collision            155
Vehicle_movement             308
Work_of_casuality           3198
Fitness_of_casuality        2635
dtype: int64
Based on this information, four columns have a substantial number of missing values: Service_year_of_vehicle, Defect_of_vehicle, Work_of_casuality, and Fitness_of_casuality. The next step is to analyze whether their missingness is related to the target variable, the severity of the accident.
from scipy.stats import chi2_contingency
import warnings
data['Accident_severity'] = pd.Categorical(data['Accident_severity'], categories=['Slight Injury', 'Serious Injury', 'Fatal injury'])
columns_to_analyze = ['Service_year_of_vehicle', 'Defect_of_vehicle', 'Work_of_casuality', 'Fitness_of_casuality']
unique_severities = data['Accident_severity'].unique()
palette = sns.color_palette("Set2", len(unique_severities))
# Suppress FutureWarnings from seaborn
warnings.filterwarnings('ignore', category=FutureWarning)
# Plot the distribution of 'Accident_severity' with and without missing values
fig, axs = plt.subplots(len(columns_to_analyze), 2, figsize=(10, 5 * len(columns_to_analyze)))
for i, column in enumerate(columns_to_analyze):
sns.countplot(x='Accident_severity', data=data[data[column].isnull()], ax=axs[i, 0], palette=palette)
axs[i, 0].set_title(f'Accident Severity with Missing {column}')
axs[i, 0].legend(labels=unique_severities, loc='upper right')
sns.countplot(x='Accident_severity', data=data[data[column].notnull()], ax=axs[i, 1], palette=palette)
axs[i, 1].set_title(f'Accident Severity without Missing {column}')
axs[i, 1].legend(labels=unique_severities, loc='upper right')
plt.tight_layout()
plt.show()
# Perform chi-squared tests
chi2_results = {}
for column in columns_to_analyze:
contingency_table = pd.crosstab(data['Accident_severity'], data[column].isnull())
chi2, p, dof, expected = chi2_contingency(contingency_table)
chi2_results[column] = {'chi2': chi2, 'p': p}
chi2_results
{'Service_year_of_vehicle': {'chi2': 1.722305458919597,
'p': 0.422674571956952},
'Defect_of_vehicle': {'chi2': 2.8583048690672266, 'p': 0.2395118382021787},
'Work_of_casuality': {'chi2': 0.07546190767365601, 'p': 0.9629719899606989},
'Fitness_of_casuality': {'chi2': 0.37391140712037346, 'p': 0.82948047860372}}
The chi-squared test compares the observed frequencies of 'Accident_severity' levels between rows where a column's value is missing and rows where it is present. Since all p-values are greater than 0.05, we conclude that missingness in these columns is not significantly associated with the levels of 'Accident_severity'. The missing values can therefore be imputed without introducing bias related to the severity of accidents.
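As a quick illustration of how to read these p-values, a toy contingency table with a deliberately strong association yields a p-value far below 0.05, while a perfectly even table does not (synthetic numbers, not from this dataset):

```python
import numpy as np
from scipy.stats import chi2_contingency

# Synthetic 2x2 tables (severity level x missing/not-missing)
strong = np.array([[90, 10],
                   [10, 90]])  # missingness concentrated in one level
even = np.array([[50, 50],
                 [50, 50]])    # missingness spread evenly

chi2_s, p_s, _, _ = chi2_contingency(strong)
chi2_e, p_e, _, _ = chi2_contingency(even)

print(f"strong association: p = {p_s:.2e}")  # far below 0.05: missingness depends on severity
print(f"no association:     p = {p_e:.3f}")  # above 0.05: missingness looks independent
```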
# Impute with the most frequent value for 'Service_year_of_vehicle' and 'Work_of_casuality'
most_frequent_service_year = data['Service_year_of_vehicle'].mode()[0]
data['Service_year_of_vehicle'] = data['Service_year_of_vehicle'].fillna(most_frequent_service_year)
most_frequent_work_of_casuality = data['Work_of_casuality'].mode()[0]
data['Work_of_casuality'] = data['Work_of_casuality'].fillna(most_frequent_work_of_casuality)
# Impute with a placeholder for 'Defect_of_vehicle' and 'Fitness_of_casuality'
data['Defect_of_vehicle'] = data['Defect_of_vehicle'].fillna('Unknown')
data['Fitness_of_casuality'] = data['Fitness_of_casuality'].fillna('Unknown')
# Handle missing values
data.shape
(12316, 32)
# Impute missing values in other columns with the mode
columns_to_impute = data.columns[data.isnull().sum() > 0]
for column in columns_to_impute:
    mode_value = data[column].mode()[0]
    data[column] = data[column].fillna(mode_value)
# Verify that there are no missing values left
missing_values_cleaned = data.isnull().sum()
print(missing_values_cleaned[missing_values_cleaned > 0])
Series([], dtype: int64)
Exploratory Data Analysis¶
fig, axes = plt.subplots(1, 2, figsize=(10, 4))
sns.histplot(data=data, x='Number_of_vehicles_involved', hue='Accident_severity', multiple="stack", ax=axes[0])
axes[0].set_title('Distribution of Number of vehicles involved')
axes[0].set_xlabel('Number of vehicles involved')
axes[0].set_ylabel('Count')
sns.histplot(data=data, x='Number_of_casualties', hue='Accident_severity', multiple="stack", ax=axes[1])
axes[1].set_title('Distribution of Number of casualties')
axes[1].set_xlabel('Number of casualties')
axes[1].set_ylabel('Count')
plt.tight_layout()
plt.show()
- The most frequent number of vehicles involved in traffic incidents is around 2, indicating a peak in the distribution. The majority of incidents involve between 1 and 3 vehicles, with very few incidents involving more than 3 vehicles. The distribution is right-skewed, meaning there are fewer incidents with a higher number of vehicles involved.
- For the number of casualties, the most frequent number is 1, representing the highest peak in the distribution. Most incidents result in 1 or 2 casualties, with very few resulting in more than 2. This distribution is also right-skewed, indicating that higher casualty numbers are less common.
The histograms show that most traffic incidents involve a small number of vehicles and casualties, which can help focus modeling efforts on these common scenarios. Since the distributions are skewed, one can transform these features (e.g., using log transformation) to reduce skewness and improve model performance.
import numpy as np
# Apply log transformation in place
data['Number_of_vehicles_involved'] = np.log(data['Number_of_vehicles_involved'] + 1)
data['Number_of_casualties'] = np.log(data['Number_of_casualties'] + 1)
fig, axes = plt.subplots(1, 2, figsize=(10, 4))
sns.histplot(data=data, x='Number_of_vehicles_involved', hue='Accident_severity', multiple="stack", ax=axes[0])
axes[0].set_title('Distribution of Number of Vehicles Involved (Log Transformed)')
axes[0].set_xlabel('Log(Number of Vehicles Involved)')
axes[0].set_ylabel('Count')
sns.histplot(data=data, x='Number_of_casualties', hue='Accident_severity', multiple="stack", ax=axes[1])
axes[1].set_title('Distribution of Number of Casualties (Log Transformed)')
axes[1].set_xlabel('Log(Number of Casualties)')
axes[1].set_ylabel('Count')
plt.tight_layout()
plt.show()
hue_variable = 'Accident_severity'
categorical_columns = data.select_dtypes(include=['object', 'category']).columns
categorical_columns = [col for col in categorical_columns if col != 'Time']
for col in categorical_columns:
if col != hue_variable: # Avoid using the hue variable as a category itself
plt.figure(figsize=(6, 4))
sns.countplot(y=data[col], hue=data[hue_variable], palette="viridis")
plt.title(f'Distribution of {col} by {hue_variable}')
plt.xlabel('Count')
plt.ylabel(col)
plt.legend(title=hue_variable)
plt.show()
Summary:
Day of the Week: Accidents are relatively evenly distributed throughout the week, with a slight increase on weekends.
Age of Driver: Young and middle-aged drivers (18-50) are more frequently involved in accidents.
Gender: Male drivers are significantly more involved in accidents than female drivers.
Education Level: Drivers with a junior high school education are most commonly involved in accidents, which may reflect socio-economic influences.
Vehicle Driver Relation: Employees are more frequently involved in accidents than vehicle owners, suggesting higher exposure or job-related driving risks.
Driving Experience: Drivers with 5-10 years of experience are involved in the most accidents, indicating that moderate to extensive experience does not necessarily equate to lower accident risk.
Type of Vehicle: Automobiles and certain types of lorries are more frequently involved in accidents, highlighting the need to consider vehicle-specific safety measures.
Owner of Vehicle: Privately owned vehicles are most commonly involved in accidents, likely reflecting their higher numbers on the road.
Accident Location: Non-specific 'Other' areas, residential, and office areas are key zones for accidents, indicating a need for targeted safety measures in these areas.
Lane and Median Type: Undivided two-way roads and roads with broken lines are higher risk, suggesting the need for better lane management and separation measures.
Road Alignment: Straight, flat roads are not inherently safe and require driver attention and speed management measures.
Junction Type: Y-shaped junctions and areas without junctions are high-risk zones, requiring better design and traffic control measures to reduce accidents.
Road Surface Type: Asphalt roads see the highest number of accidents, indicating that their prevalence might be the main factor.
Road Surface Conditions: Dry roads are the most common condition for accidents, but wet or damp conditions also pose significant risks.
Light Conditions: Daylight conditions have the highest number of accidents, likely due to higher traffic volumes, but night driving still presents dangers.
Weather Conditions: Normal weather conditions see the most accidents, with rainy conditions showing increased risk, underscoring the need for careful driving in adverse weather.
Type of Collision: Collisions with parked vehicles and between moving vehicles are the most frequent, suggesting the need for better parking management and driver awareness.
Vehicle Movement: Most accidents occur during straightforward driving, emphasizing the importance of vigilance even in seemingly low-risk scenarios.
Casualty Class: Drivers and riders are the primary victims in accidents, indicating the need for enhanced driver protection measures. Pedestrians and passengers also require targeted safety strategies.
Sex of Casualty: Male casualties are more common, suggesting a focus on male-targeted safety campaigns to address higher exposure or risk behaviors.
Age Band of Casualty: Young adults (18-30) and middle-aged individuals (31-50) are the most affected by accidents, indicating a need for targeted safety measures for these age groups.
Casualty Severity: Non-fatal but serious injuries (level 3) are common, highlighting the need for improved safety features in vehicles and road infrastructure to reduce injury severity.
Pedestrian Movement: Most recorded casualties were not pedestrians, consistent with vehicle occupants being the dominant casualty class.
Cause of Accident: Common causes like failure to give priority and improper lane changes point to areas where traffic enforcement and driver education can make a significant impact.
Service_year_of_vehicle: The number of service years is unknown for a significant proportion of slight injuries and for some serious injuries.
Defect_of_vehicle: Most accidents involve vehicles with no reported defect; these are predominantly slight injuries, with a notable number of serious injuries and a few fatal injuries.
Work_of_casuality: Driver casualties predominantly sustain slight injuries, with a small proportion of serious injuries.
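The category-by-severity patterns summarized above can be quantified with a row-normalized crosstab, giving the share of each severity level within every category value. A minimal sketch with made-up numbers (the column names match the dataset; the rows here do not):

```python
import pandas as pd

# Illustrative stand-in for `data`, using the same column names
df = pd.DataFrame({
    "Day_of_week": ["Monday", "Monday", "Sunday", "Sunday", "Sunday"],
    "Accident_severity": ["Slight Injury", "Serious Injury",
                          "Slight Injury", "Slight Injury", "Fatal injury"],
})

# normalize="index" turns counts into within-row shares
rates = pd.crosstab(df["Day_of_week"], df["Accident_severity"], normalize="index")
print(rates.round(2))
```

Applied to the full dataset, this turns each qualitative bullet above into a comparable percentage.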
# Check for Outliers
numerical_columns = data.select_dtypes(include=['float64', 'int64']).columns
categorical_columns = data.select_dtypes(include=['object', 'category']).columns
fig, axes = plt.subplots(len(numerical_columns)//2, 2, figsize=(12, len(numerical_columns)*2))
axes = axes.flatten()
for i, col in enumerate(numerical_columns):
sns.boxplot(data=data[col], ax=axes[i], color='skyblue')
axes[i].set_title(f'Boxplot of {col}')
axes[i].set_xlabel(col)
plt.tight_layout()
plt.show()
# Analyze Correlations
corr = data[numerical_columns].corr()
plt.figure(figsize=(5, 5))
sns.heatmap(corr, annot=True, cmap='coolwarm', fmt=".2f")
plt.title('Correlation Matrix')
plt.show()
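The Pearson matrix above only covers the two numeric columns; most features here are categorical. One option for those (an addition, not part of the notebook's pipeline) is Cramér's V, which maps a chi-squared statistic onto a 0-1 association scale:

```python
import numpy as np
import pandas as pd
from scipy.stats import chi2_contingency

def cramers_v(x: pd.Series, y: pd.Series) -> float:
    """Cramér's V between two categorical series: 0 = no association, 1 = perfect."""
    table = pd.crosstab(x, y)
    chi2, _, _, _ = chi2_contingency(table, correction=False)
    n = table.to_numpy().sum()
    r, k = table.shape
    return float(np.sqrt(chi2 / (n * (min(r, k) - 1))))

# Perfectly associated toy categories give V = 1
a = pd.Series(["x", "x", "y", "y"] * 10)
b = pd.Series(["p", "p", "q", "q"] * 10)
print(cramers_v(a, b))
```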
!pip install ydata-profiling > /dev/null 2>&1
from ydata_profiling import ProfileReport
profile = ProfileReport(data, title="Traffic Incident Data Profile Report", explorative=True)
profile.to_notebook_iframe()
# Drop highly correlated features based on the analysis
features_to_drop = ['Casualty_class', 'Road_surface_conditions', 'Age_band_of_casualty','Sex_of_casualty']
data.drop(columns=features_to_drop, inplace=True)
profile = ProfileReport(data, title="Traffic Incident Data Profile Report", explorative=True)
profile.to_notebook_iframe()
Feature Engineering¶
Time-based features¶
# Parse 'Time' column and extract the hour
data['Time'] = pd.to_datetime(data['Time'], format='%H:%M:%S', errors='coerce')
data['hour'] = data['Time'].dt.hour
data['hour'] = data['hour'].fillna(-1)
data.head()
| Time | Day_of_week | Age_band_of_driver | Sex_of_driver | Educational_level | Vehicle_driver_relation | Driving_experience | Type_of_vehicle | Owner_of_vehicle | Service_year_of_vehicle | ... | Number_of_vehicles_involved | Number_of_casualties | Vehicle_movement | Casualty_severity | Work_of_casuality | Fitness_of_casuality | Pedestrian_movement | Cause_of_accident | Accident_severity | hour | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 1900-01-01 17:02:00 | Monday | 18-30 | Male | Above high school | Employee | 1-2yr | Automobile | Owner | Above 10yr | ... | 1.098612 | 1.098612 | Going straight | na | Driver | Unknown | Not a Pedestrian | Moving Backward | Slight Injury | 17 |
| 1 | 1900-01-01 17:02:00 | Monday | 31-50 | Male | Junior high school | Employee | Above 10yr | Public (> 45 seats) | Owner | 5-10yrs | ... | 1.098612 | 1.098612 | Going straight | na | Driver | Unknown | Not a Pedestrian | Overtaking | Slight Injury | 17 |
| 2 | 1900-01-01 17:02:00 | Monday | 18-30 | Male | Junior high school | Employee | 1-2yr | Lorry (41?100Q) | Owner | Unknown | ... | 1.098612 | 1.098612 | Going straight | 3 | Driver | Unknown | Not a Pedestrian | Changing lane to the left | Serious Injury | 17 |
| 3 | 1900-01-01 01:06:00 | Sunday | 18-30 | Male | Junior high school | Employee | 5-10yr | Public (> 45 seats) | Governmental | Unknown | ... | 1.098612 | 1.098612 | Going straight | 3 | Driver | Normal | Not a Pedestrian | Changing lane to the right | Slight Injury | 1 |
| 4 | 1900-01-01 01:06:00 | Sunday | 18-30 | Male | Junior high school | Employee | 2-5yr | Automobile | Owner | 5-10yrs | ... | 1.098612 | 1.098612 | Going straight | na | Driver | Unknown | Not a Pedestrian | Overtaking | Slight Injury | 1 |
5 rows × 29 columns
print(data[['Time', 'hour']].head())
                 Time  hour
0 1900-01-01 17:02:00    17
1 1900-01-01 17:02:00    17
2 1900-01-01 17:02:00    17
3 1900-01-01 01:06:00     1
4 1900-01-01 01:06:00     1
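One caveat with the raw `hour` feature: 23:00 and 00:00 are numerically far apart but temporally adjacent. A common refinement (not applied above; shown as an optional sketch) is a cyclic sine/cosine encoding:

```python
import numpy as np
import pandas as pd

hours = pd.Series([0, 6, 12, 23])

# Map the 24-hour clock onto the unit circle so hour 23 lands next to hour 0
encoded = pd.DataFrame({
    "hour": hours,
    "hour_sin": np.sin(2 * np.pi * hours / 24),
    "hour_cos": np.cos(2 * np.pi * hours / 24),
})
print(encoded.round(3))
```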
Encoding categorical variables¶
data.columns
Index(['Time', 'Day_of_week', 'Age_band_of_driver', 'Sex_of_driver',
'Educational_level', 'Vehicle_driver_relation', 'Driving_experience',
'Type_of_vehicle', 'Owner_of_vehicle', 'Service_year_of_vehicle',
'Defect_of_vehicle', 'Area_accident_occured', 'Lanes_or_Medians',
'Road_allignment', 'Types_of_Junction', 'Road_surface_type',
'Light_conditions', 'Weather_conditions', 'Type_of_collision',
'Number_of_vehicles_involved', 'Number_of_casualties',
'Vehicle_movement', 'Casualty_severity', 'Work_of_casuality',
'Fitness_of_casuality', 'Pedestrian_movement', 'Cause_of_accident',
'Accident_severity', 'hour'],
dtype='object')
# One-hot encode categorical variables
categorical_columns = data.select_dtypes(include=['object', 'category']).columns
data_encoded = pd.get_dummies(data, columns=categorical_columns, drop_first=True)
# Define target and features
severity_columns = [col for col in data_encoded.columns if 'Accident_severity_' in col]
X = data_encoded.drop(severity_columns + ['Time'], axis=1) # Ensure 'Time' is not in the features
y = data_encoded[severity_columns]
# Convert target variable to single categorical column
y = y.idxmax(axis=1).apply(lambda x: x.replace('Accident_severity_', ''))
data_encoded.shape
(12316, 171)
data_encoded.columns
Index(['Time', 'Number_of_vehicles_involved', 'Number_of_casualties', 'hour',
'Day_of_week_Monday', 'Day_of_week_Saturday', 'Day_of_week_Sunday',
'Day_of_week_Thursday', 'Day_of_week_Tuesday', 'Day_of_week_Wednesday',
...
'Cause_of_accident_No priority to vehicle', 'Cause_of_accident_Other',
'Cause_of_accident_Overloading', 'Cause_of_accident_Overspeed',
'Cause_of_accident_Overtaking', 'Cause_of_accident_Overturning',
'Cause_of_accident_Turnover', 'Cause_of_accident_Unknown',
'Accident_severity_Serious Injury', 'Accident_severity_Fatal injury'],
dtype='object', length=171)
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.feature_selection import SelectFromModel
from imblearn.over_sampling import SMOTE
# Handle imbalanced data using SMOTE
smote = SMOTE(random_state=42)
X_resampled, y_resampled = smote.fit_resample(X, y)
# Convert target labels to numerical format using LabelEncoder
label_encoder = LabelEncoder()
y_resampled_encoded = label_encoder.fit_transform(y_resampled)
# Feature selection using Random Forest
rf = RandomForestClassifier(random_state=42)
rf.fit(X_resampled, y_resampled_encoded)
selector = SelectFromModel(rf, threshold="mean", prefit=True)
support_mask = selector.get_support()
# Select features from the original DataFrame using the support mask
X_selected = X_resampled.loc[:, support_mask]
# Check the shape of the selected features and their names
print("Original shape:", X_resampled.shape)
print("Selected shape:", X_selected.shape)
Original shape: (24316, 168)
Selected shape: (24316, 49)
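To see which features drove the selection, the fitted forest's `feature_importances_` can be ranked. A self-contained sketch with synthetic data standing in for `rf` and `X_resampled` (one deliberately informative feature, two noise features):

```python
import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestClassifier

rng = np.random.default_rng(42)
X_demo = pd.DataFrame({
    "informative": rng.normal(size=200),  # fully determines the label below
    "noise_a": rng.normal(size=200),
    "noise_b": rng.normal(size=200),
})
y_demo = (X_demo["informative"] > 0).astype(int)

rf_demo = RandomForestClassifier(random_state=42).fit(X_demo, y_demo)
importances = pd.Series(rf_demo.feature_importances_, index=X_demo.columns)
print(importances.sort_values(ascending=False))  # "informative" ranks first
```

In the notebook, the same `pd.Series(rf.feature_importances_, index=X_resampled.columns)` pattern reveals which of the 49 retained columns cleared the mean-importance threshold.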
X_selected.head()
| Number_of_vehicles_involved | Number_of_casualties | hour | Day_of_week_Saturday | Day_of_week_Sunday | Day_of_week_Thursday | Age_band_of_driver_31-50 | Age_band_of_driver_Over 51 | Educational_level_Elementary school | Educational_level_High school | ... | Vehicle_movement_Going straight | Casualty_severity_3 | Casualty_severity_na | Work_of_casuality_Self-employed | Fitness_of_casuality_Normal | Fitness_of_casuality_Unknown | Cause_of_accident_Changing lane to the right | Cause_of_accident_Driving carelessly | Cause_of_accident_Moving Backward | Cause_of_accident_No distancing | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 1.098612 | 1.098612 | 17 | False | False | False | False | False | False | False | ... | True | False | True | False | False | True | False | False | True | False |
| 1 | 1.098612 | 1.098612 | 17 | False | False | False | True | False | False | False | ... | True | False | True | False | False | True | False | False | False | False |
| 2 | 1.098612 | 1.098612 | 17 | False | False | False | False | False | False | False | ... | True | True | False | False | False | True | False | False | False | False |
| 3 | 1.098612 | 1.098612 | 1 | False | True | False | False | False | False | False | ... | True | True | False | False | True | False | True | False | False | False |
| 4 | 1.098612 | 1.098612 | 1 | False | True | False | False | False | False | False | ... | True | False | True | False | False | True | False | False | False | False |
5 rows × 49 columns
from sklearn.preprocessing import StandardScaler
# Standardize the features
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X_selected)
y_resampled_encoded
array([1, 1, 1, ..., 0, 0, 0])
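The integer array above decodes back to severity names via `label_encoder.classes_` (index i is the class for integer label i, sorted alphabetically). A minimal illustration with the same label strings:

```python
from sklearn.preprocessing import LabelEncoder

le = LabelEncoder()
codes = le.fit_transform(["Slight Injury", "Serious Injury", "Slight Injury"])

# classes_ is sorted alphabetically, so "Serious Injury" maps to 0
print(dict(enumerate(le.classes_)))
print(le.inverse_transform(codes))
```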
Model selection¶
# Split data into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(X_scaled, y_resampled_encoded, test_size=0.2, random_state=42)
X_scaled_df = pd.DataFrame(X_scaled, columns=X_selected.columns)
X_test_df = pd.DataFrame(X_test, columns=X_scaled_df.columns)
y_test_df = pd.Series(y_test, name='target')
X_test_df.to_csv('X_test.csv', index=False)
y_test_df.to_csv('y_test.csv', index=False)
Naive Bayes Classifier¶
from sklearn.naive_bayes import GaussianNB
from sklearn.metrics import classification_report, accuracy_score
# Initialize the model
nb_model = GaussianNB()
# Fit the model on the training data
nb_model.fit(X_train, y_train)
# Predict and evaluate
nb_pred = nb_model.predict(X_test)
print("Naive Bayes Classifier")
print(classification_report(y_test, nb_pred))
print(f'Accuracy: {accuracy_score(y_test, nb_pred)}')
Naive Bayes Classifier
precision recall f1-score support
0 0.88 0.93 0.91 2431
1 0.92 0.88 0.90 2433
accuracy 0.90 4864
macro avg 0.90 0.90 0.90 4864
weighted avg 0.90 0.90 0.90 4864
Accuracy: 0.9031661184210527
Logistic Regression¶
from sklearn.linear_model import LogisticRegression
# Train Logistic Regression
lr_model = LogisticRegression(max_iter=1000, random_state=42)
lr_model.fit(X_train, y_train)
# Predict and evaluate
lr_pred = lr_model.predict(X_test)
print("Logistic Regression")
print(classification_report(y_test, lr_pred))
print(f'Accuracy: {accuracy_score(y_test, lr_pred)}')
Logistic Regression
precision recall f1-score support
0 0.99 0.98 0.98 2431
1 0.98 0.99 0.98 2433
accuracy 0.98 4864
macro avg 0.98 0.98 0.98 4864
weighted avg 0.98 0.98 0.98 4864
Accuracy: 0.9845805921052632
Decision Tree¶
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import classification_report, accuracy_score
# Train Decision Tree
dt_model = DecisionTreeClassifier(random_state=42)
dt_model.fit(X_train, y_train)
# Predict and evaluate
dt_pred = dt_model.predict(X_test)
print("Decision Tree Classifier")
print(classification_report(y_test, dt_pred))
print(f'Accuracy: {accuracy_score(y_test, dt_pred)}')
Decision Tree Classifier
precision recall f1-score support
0 0.98 0.99 0.99 2431
1 0.99 0.98 0.99 2433
accuracy 0.99 4864
macro avg 0.99 0.99 0.99 4864
weighted avg 0.99 0.99 0.99 4864
Accuracy: 0.9876644736842105
Random Forest¶
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV
# Define an expanded parameter grid
param_grid_rf = {
'n_estimators': [200, 300],
'max_features': ['sqrt', 'log2'],
'max_depth': [8, 10],
'criterion': ['gini', 'entropy']
}
# Initialize the model
rf_model = RandomForestClassifier(random_state=42)
# Perform Grid Search
grid_search_rf = GridSearchCV(estimator=rf_model, param_grid=param_grid_rf, cv=5, n_jobs=-1, verbose=2)
grid_search_rf.fit(X_train, y_train)
# Best parameters and score
print(f"Best parameters for Random Forest: {grid_search_rf.best_params_}")
print(f"Best score for Random Forest: {grid_search_rf.best_score_}")
# Predict and evaluate
rf_best_model = grid_search_rf.best_estimator_
rf_pred = rf_best_model.predict(X_test)
print("Random Forest Classifier after Hyperparameter Tuning")
print(classification_report(y_test, rf_pred))
print(f'Accuracy: {accuracy_score(y_test, rf_pred)}')
Fitting 5 folds for each of 16 candidates, totalling 80 fits
Best parameters for Random Forest: {'criterion': 'gini', 'max_depth': 10, 'max_features': 'log2', 'n_estimators': 300}
Best score for Random Forest: 0.9949105542485164
Random Forest Classifier after Hyperparameter Tuning
precision recall f1-score support
0 1.00 0.99 1.00 2431
1 0.99 1.00 1.00 2433
accuracy 1.00 4864
macro avg 1.00 1.00 1.00 4864
weighted avg 1.00 1.00 1.00 4864
Accuracy: 0.9954769736842105
Light GBM¶
import lightgbm as lgb
lgb_model = lgb.LGBMClassifier(random_state=42)
lgb_model.fit(X_train, y_train)
# Predict and evaluate
lgb_pred = lgb_model.predict(X_test)
print("LightGBM Classifier")
print(classification_report(y_test, lgb_pred))
print(f'Accuracy: {accuracy_score(y_test, lgb_pred)}')
[LightGBM] [Info] Number of positive: 9725, number of negative: 9727
[LightGBM] [Info] Auto-choosing row-wise multi-threading, the overhead of testing was 0.003679 seconds.
You can set `force_row_wise=true` to remove the overhead.
And if memory is not enough, you can set `force_col_wise=true`.
[LightGBM] [Info] Total Bins 663
[LightGBM] [Info] Number of data points in the train set: 19452, number of used features: 49
[LightGBM] [Info] [binary:BoostFromScore]: pavg=0.499949 -> initscore=-0.000206
[LightGBM] [Info] Start training from score -0.000206
LightGBM Classifier
precision recall f1-score support
0 1.00 0.99 1.00 2431
1 0.99 1.00 1.00 2433
accuracy 1.00 4864
macro avg 1.00 1.00 1.00 4864
weighted avg 1.00 1.00 1.00 4864
Accuracy: 0.9954769736842105
XGBoost¶
from sklearn.model_selection import GridSearchCV
import xgboost as xgb
# Define parameter grid
param_grid_xgb = {
'n_estimators': [200,400],
'max_depth': [8,10, 12],
'learning_rate': [0.05, 0.1, 0.2],
}
# Initialize the model
xgb_model = xgb.XGBClassifier(random_state=42)
# Perform Grid Search
grid_search_xgb = GridSearchCV(estimator=xgb_model, param_grid=param_grid_xgb, cv=5, n_jobs=-1, verbose=2)
grid_search_xgb.fit(X_train, y_train)
# Best parameters and score
print(f"Best parameters for XGBoost: {grid_search_xgb.best_params_}")
print(f"Best score for XGBoost: {grid_search_xgb.best_score_}")
# Predict and evaluate
xgb_best_model = grid_search_xgb.best_estimator_
xgb_pred = xgb_best_model.predict(X_test)
print("XGBoost Classifier after Hyperparameter Tuning")
print(classification_report(y_test, xgb_pred))
print(f'Accuracy: {accuracy_score(y_test, xgb_pred)}')
Fitting 5 folds for each of 18 candidates, totalling 90 fits
Best parameters for XGBoost: {'learning_rate': 0.05, 'max_depth': 8, 'n_estimators': 200}
Best score for XGBoost: 0.9957330706481702
XGBoost Classifier after Hyperparameter Tuning
precision recall f1-score support
0 1.00 0.99 1.00 2431
1 0.99 1.00 1.00 2433
accuracy 1.00 4864
macro avg 1.00 1.00 1.00 4864
weighted avg 1.00 1.00 1.00 4864
Accuracy: 0.99609375
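Note that every selected XGBoost value (learning_rate 0.05, max_depth 8, n_estimators 200) sits on the lower edge of its searched range, which suggests extending the grid downward might do even better. A small hypothetical helper for flagging this situation:

```python
def params_on_boundary(best_params, param_grid):
    """Return names of parameters whose selected value lies at the edge
    of the range that was searched."""
    flagged = []
    for name, values in param_grid.items():
        ordered = sorted(values)
        if best_params[name] in (ordered[0], ordered[-1]):
            flagged.append(name)
    return flagged

# Values taken from the grid search above
param_grid_xgb = {'n_estimators': [200, 400], 'max_depth': [8, 10, 12],
                  'learning_rate': [0.05, 0.1, 0.2]}
best = {'learning_rate': 0.05, 'max_depth': 8, 'n_estimators': 200}
print(params_on_boundary(best, param_grid_xgb))  # all three are flagged
```

With two values for n_estimators both are technically edges, so that flag is uninformative there, but max_depth and learning_rate landing on their minimums is a genuine hint to try smaller values.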
Neural Network¶
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Input, Dense, Dropout
from tensorflow.keras.utils import to_categorical
from sklearn.metrics import classification_report, accuracy_score
# Convert target labels to one-hot encoding
y_train_onehot = to_categorical(y_train)
y_test_onehot = to_categorical(y_test)
# Build the neural network model; an explicit Input layer avoids the
# Keras deprecation warning triggered by passing input_dim to Dense
model = Sequential()
model.add(Input(shape=(X_train.shape[1],)))
model.add(Dense(128, activation='relu'))
model.add(Dropout(0.4))
model.add(Dense(64, activation='relu'))
model.add(Dropout(0.3))
model.add(Dense(32, activation='relu'))
model.add(Dropout(0.2))
model.add(Dense(y_train_onehot.shape[1], activation='softmax'))
# Compile the model
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
# Train the model
model.fit(X_train, y_train_onehot, epochs=150, batch_size=32, validation_split=0.2, verbose=2)
# Predict and evaluate
y_pred_proba = model.predict(X_test)
y_pred = y_pred_proba.argmax(axis=1)
print("Neural Network Classifier")
print(classification_report(y_test, y_pred))
print(f'Accuracy: {accuracy_score(y_test, y_pred)}')
Epoch 1/150
487/487 - 1s - 2ms/step - accuracy: 0.9088 - loss: 0.2071 - val_accuracy: 0.9900 - val_loss: 0.0426
Epoch 2/150
487/487 - 0s - 671us/step - accuracy: 0.9817 - loss: 0.0660 - val_accuracy: 0.9946 - val_loss: 0.0313
Epoch 3/150
487/487 - 0s - 725us/step - accuracy: 0.9876 - loss: 0.0507 - val_accuracy: 0.9951 - val_loss: 0.0288
[... epochs 4-149 omitted: training accuracy climbs to ~0.998 while validation accuracy plateaus near 0.995 and validation loss drifts up from ~0.03 to ~0.11 ...]
Epoch 150/150
487/487 - 0s - 660us/step - accuracy: 0.9983 - loss: 0.0066 - val_accuracy: 0.9951 - val_loss: 0.1093
1/152 ━━━━━━━━━━━━━━━━━━━━ 4s 30ms/step
152/152 ━━━━━━━━━━━━━━━━━━━━ 0s 319us/step
Neural Network Classifier
precision recall f1-score support
0 1.00 0.99 0.99 2431
1 0.99 1.00 0.99 2433
accuracy 0.99 4864
macro avg 0.99 0.99 0.99 4864
weighted avg 0.99 0.99 0.99 4864
Accuracy: 0.9925986842105263
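The training log above shows validation loss drifting upward after roughly epoch 10 while training loss keeps falling, which is a classic overfitting signal; Keras's `EarlyStopping(monitor='val_loss', patience=..., restore_best_weights=True)` callback would cut training short instead of running all 150 epochs. Its patience rule reduces to the following pure-Python sketch (not Keras's actual implementation):

```python
def early_stop_epoch(val_losses, patience=10):
    """1-based epoch at which patience-based early stopping would halt:
    stop once val_loss has failed to improve for `patience` epochs."""
    best = float("inf")
    wait = 0
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best:          # strict improvement resets the counter
            best, wait = loss, 0
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return len(val_losses)       # never triggered: train to completion

# Mimics the run above: val_loss improves early, then stops improving
history = [0.043, 0.031, 0.029, 0.028, 0.029, 0.029, 0.028, 0.030, 0.028, 0.028]
print(early_stop_epoch(history, patience=3))  # → 7
```

With `restore_best_weights=True`, Keras additionally rolls the model back to the epoch with the lowest validation loss, so the reported test accuracy would reflect the best checkpoint rather than the last one.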
Plots for Receiver Operating Characteristic (ROC)¶
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, auc
def plot_multiple_roc_curves(y_test, y_pred_probas, model_names):
plt.figure(figsize=(12, 8))
for y_pred_proba, model_name in zip(y_pred_probas, model_names):
fpr, tpr, _ = roc_curve(y_test, y_pred_proba)
roc_auc = auc(fpr, tpr)
plt.plot(fpr, tpr, lw=2, label=f'{model_name} (AUC = {roc_auc:.2f})')
plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Receiver Operating Characteristic')
plt.legend(loc="lower right")
plt.show()
# ROC curves need scores/probabilities, not hard class labels: with 0/1
# predictions each curve collapses to a single operating point.
# nb_model, lr_model and dt_model are assumed to be the fitted models
# from the earlier cells.
y_pred_probas = [
    nb_model.predict_proba(X_test)[:, 1],
    lr_model.predict_proba(X_test)[:, 1],
    dt_model.predict_proba(X_test)[:, 1],
    rf_best_model.predict_proba(X_test)[:, 1],
    lgb_model.predict_proba(X_test)[:, 1],
    xgb_best_model.predict_proba(X_test)[:, 1],
    y_pred_proba[:, 1],  # neural network softmax output for class 1
]
model_names = ['Naive Bayes', 'Logistic Regression', 'Decision Tree', 'Random Forest', 'LightGBM', 'XGBoost', 'Neural Network']
# Plot ROC curves
plot_multiple_roc_curves(y_test, y_pred_probas, model_names)
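One caveat worth demonstrating: `roc_curve` should be fed continuous scores or probabilities, not thresholded 0/1 predictions, since hard labels collapse the curve to a single operating point and distort the AUC. A toy example with synthetic scores (not the project's models):

```python
import numpy as np
from sklearn.metrics import roc_curve, auc

y_true = np.array([0, 0, 0, 1, 1, 1])
scores = np.array([0.10, 0.40, 0.60, 0.35, 0.80, 0.90])  # imperfect classifier

# AUC from continuous scores: every threshold is explored
fpr, tpr, _ = roc_curve(y_true, scores)
auc_scores = auc(fpr, tpr)

# AUC from thresholded labels: only one non-trivial operating point remains
labels = (scores > 0.5).astype(int)
fpr_l, tpr_l, _ = roc_curve(y_true, labels)
auc_labels = auc(fpr_l, tpr_l)

print(round(auc_scores, 3), round(auc_labels, 3))  # 0.778 vs 0.667
```

For the models above this means using each classifier's `predict_proba(X_test)[:, 1]` (or the neural network's softmax column for class 1) rather than the `*_pred` label arrays.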
The ROC curves show that the models perform exceptionally well, with AUC values very close to 1.00. This suggests the classes are well separated and that these models are effective at distinguishing between them. Naive Bayes, while still strong, performs relatively worse, possibly because its assumption of feature independence does not hold in this dataset.
SHAP values¶
!pip install shap xgboost > /dev/null 2>&1
import shap
import pandas as pd
# Explain the tuned XGBoost model with tree SHAP
explainer = shap.TreeExplainer(xgb_best_model)
X_test_df = X_test if isinstance(X_test, pd.DataFrame) else pd.DataFrame(X_test)
shap_values = explainer.shap_values(X_test_df)
# For binary classification TreeExplainer may return a single array;
# wrap it in a list so the indexing below is uniform
shap_values = shap_values if isinstance(shap_values, list) else [shap_values]
# Beeswarm summary plot
shap.summary_plot(shap_values[0], X_test_df)
# Bar plot for feature importance
shap.summary_plot(shap_values[0], X_test_df, plot_type="bar")
# Force plot for an individual prediction
shap.initjs()
shap.force_plot(explainer.expected_value, shap_values[0][0].reshape(1, -1), X_test_df.iloc[0].values.reshape(1, -1))
# Dependence plot for a single feature
actual_feature_name = X_test_df.columns[0]  # replace with the feature you want to plot
shap.dependence_plot(actual_feature_name, shap_values[0], X_test_df)
indices_of_interest = [1, 27, 28]
X_selected.columns[indices_of_interest]
Index(['Number_of_casualties',
'Lanes_or_Medians_Two-way (divided with broken lines road marking)',
'Lanes_or_Medians_Undivided Two way'],
dtype='object')
1. SHAP Summary Plot (Beeswarm Plot)¶
Insights:
- Feature Importance: The features are ranked by their importance to the model. The higher a feature is on the plot, the more important it is.
- Impact on Model Output: The color represents the feature value (red for high, blue for low). The spread along the x-axis shows the impact of the feature on the prediction. Points farther from the center (0) have a larger impact.
- Interaction: Overlapping points indicate interactions between features.
- Feature 1 (likely Number_of_casualties) is the most important, with a wide spread, indicating a strong influence on the model output.
- Features 28 (Lanes_or_Medians_Undivided Two way) and 27 (Lanes_or_Medians_Two-way (divided with broken lines road marking)) also rank highly.
2. SHAP Bar Plot (Feature Importance)¶
Insights:
- Feature Importance: Confirms the importance of each feature in a more straightforward way. The x-axis represents the average impact on the model output magnitude.
- Top Features: Features 1, 28, and 27 are the top three features, confirming their significant role in the model.
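The bar plot's ranking is simply the mean absolute SHAP value per feature, and can be reproduced directly from the SHAP matrix; the matrix below is hypothetical, standing in for `shap_values[0]`:

```python
import numpy as np

# Hypothetical SHAP matrix: 4 samples x 3 features
shap_matrix = np.array([
    [ 2.0, -0.1,  0.5],
    [-1.5,  0.2, -0.4],
    [ 1.8,  0.1,  0.6],
    [-2.1, -0.2, -0.5],
])
# Mean |SHAP| per feature = the bar lengths in the summary bar plot
mean_abs = np.abs(shap_matrix).mean(axis=0)
ranking = mean_abs.argsort()[::-1]
print(ranking.tolist())  # → [0, 2, 1]
```

Applying the same two lines to `shap_values[0]` from the cell above would recover the Feature 1 / Feature 28 / Feature 27 ordering numerically rather than by reading it off the plot.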
3. SHAP Dependence Plot¶
- Feature 0 Impact: This plot shows how Feature 0 (the first column of the feature matrix) impacts the model predictions, with point color indicating a potential interacting feature.
To summarize:¶
Top Influential Features:
- Number_of_casualties (Feature 1)
- Lanes_or_Medians_Undivided Two way (Feature 28)
- Lanes_or_Medians_Two-way (divided with broken lines road marking) (Feature 27)
Impact on Model Output:
- These features have significant impacts on the predictions. For instance, a higher number of casualties strongly influences the model's prediction.
- The SHAP values provide a measure of this impact, with larger values indicating a greater effect.
Feature Interactions:
- The dependence plot can reveal interactions between features, for example between Feature 0 and another feature (represented by color).
Model Interpretation:
- These insights help to understand which features the model relies on most for making predictions. This can be valuable for model debugging, feature engineering, and communicating model behavior to stakeholders.
By using SHAP values and these plots, one can gain a deeper understanding of the model's decision-making process, which helps in improving the model and ensuring its transparency.